Transfer Learning - Disney Image Classification

by CM


Posted on April 30, 2020



The Goal:

In this article, we will explore the magic of Transfer Learning (TL). In particular, we will build a dataset of Disney Princesses and try to predict which Disney Princess someone most resembles. Because we do not want to spend too much time collecting Disney images for our dataset, we will use a pretrained model. Specifically, we will build our base model from MobileNet V2, developed at Google. It is pre-trained on the ImageNet dataset, a large dataset consisting of 1.4M images and 1000 classes. ImageNet is a research training dataset with a wide variety of categories like jackfruit and syringe. On top of this base model, we will add our own classification layer for the Disney princesses. The outcome will be a Convolutional Neural Network (CNN). Let's see how well we will do.

Image classification:

Image classification is a supervised learning problem. We define a set of target classes (in our case Disney Princesses), and train a model to recognize them using labeled example photos. In our example, we will make use of TensorFlow 2.x in order to build, train, and optimize our model.


Key components are:

Let's jump right into the code. First, we import all required dependencies:
(1) TensorFlow: a free and open-source software library for dataflow and differentiable programming across a range of tasks. It is a symbolic math library and is also used for machine learning applications such as neural networks.
(2) Keras: a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.
(3) NumPy: a library for the Python programming language that adds support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays.
(4) Matplotlib: a plotting library for the Python programming language and its numerical mathematics extension NumPy. It provides an object-oriented API for embedding plots into applications using general-purpose GUI toolkits like Tkinter, wxPython, Qt, or GTK+.
(5) os: a module that provides a portable way of using operating-system-dependent functionality.
(6) zipfile: the ZIP file format is a common archive and compression standard. This module provides tools to create, read, write, append, and list a ZIP file. Any advanced use of this module will require an understanding of the format, as defined in the PKZIP Application Note.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import numpy as np
import matplotlib.pyplot as plt

import os
import zipfile

Second, we will load our data into Colab from Google Drive. To do so, we need to mount Google Drive and authenticate ourselves in order to access data from the cloud. In my case, I stored the Disney dataset in "/content/gdrive/My Drive/Datasets/". The file name is "disney_princesses_dataset.zip". Of course, you can also upload the dataset manually to Colab or use any other storage solution.

#Mount Google Drive
from google.colab import drive
drive.mount('/content/gdrive')

#Create a temporary directory in which we will store the picture data
!mkdir disney4
#Go to directory (not necessary in case you specify the extraction path directly via !unzip)
os.chdir('disney4')

#Unzip our .zip file in the directory
!unzip "/content/gdrive/My Drive/Datasets/disney_princesses_dataset.zip" -d 'disney4'

#Path where we have stored our Pictures in Colab
PATH = 'disney4'
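
The `zipfile` module imported earlier is not actually needed for the `!unzip` shell command above. As a minimal sketch of a pure-Python alternative, the archive could also be extracted with `zipfile.ZipFile.extractall`. The helper name `extract_dataset` and the tiny demo archive are assumptions made for illustration only; they are not part of the original notebook.

```python
import os
import tempfile
import zipfile


def extract_dataset(zip_path, target_dir):
    """Extract every member of zip_path into target_dir and return the member names."""
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target_dir)
        return archive.namelist()


# Self-contained demo on a tiny throwaway archive (not the real dataset).
demo_root = tempfile.mkdtemp()
demo_zip = os.path.join(demo_root, "demo.zip")
with zipfile.ZipFile(demo_zip, "w") as archive:
    archive.writestr("train/Anna/Anna.1.jpg", b"not a real image")

extracted = extract_dataset(demo_zip, os.path.join(demo_root, "disney4"))
print(extracted)
```

In Colab, the same helper could presumably be pointed at the archive on the mounted drive instead of the demo file.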

After we enter our authentication code, the directories will be created and the images extracted. Below is a snapshot of what your output should look like.

==========================
OUTPUT
==========================

Enter your authorization code:
··········
Mounted at /content/gdrive
Archive:  /content/gdrive/My Drive/Datasets/pics4.zip
   creating: disney4/train/
   creating: disney4/train/Anna/
  inflating: disney4/train/Anna/Anna.1.jpg
  inflating: disney4/train/Anna/Anna.10.jpg
  inflating: disney4/train/Anna/Anna.2.jpg
  inflating: disney4/train/Anna/Anna.3.jpg
 extracting: disney4/train/Anna/Anna.4.jpg
 extracting: disney4/train/Anna/Anna.5.jpg
  inflating: disney4/train/Anna/Anna.6.jpg
  inflating: disney4/train/Anna/Anna.7.jpg
  inflating: disney4/train/Anna/Anna.8.jpg
  inflating: disney4/train/Anna/Anna.9.jpg
.........

Next, we will specify the training and validation directories. This is easily done using the folder names of the dataset (note that we have one train and one validation folder). We use the os.path.join function to build the paths of the respective directories.

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
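
Before training, it can help to check how many images each class folder contains. The following is a minimal sketch, assuming the `train/<class name>/<image>` layout shown in the unzip output; the helper `count_images_per_class` and the throwaway demo folders are hypothetical and only serve to keep the example self-contained.

```python
import os
import tempfile


def count_images_per_class(split_dir):
    """Return {class_name: number_of_files} for each subfolder of split_dir."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if os.path.isdir(class_dir):
            counts[class_name] = len(os.listdir(class_dir))
    return counts


# Demo on a throwaway folder tree that mimics the train/<class>/<image> layout.
demo_train_dir = os.path.join(tempfile.mkdtemp(), "train")
for class_name, n_images in [("Anna", 2), ("Elsa", 3)]:
    os.makedirs(os.path.join(demo_train_dir, class_name))
    for i in range(n_images):
        file_path = os.path.join(demo_train_dir, class_name, f"{class_name}.{i}.jpg")
        open(file_path, "w").close()  # empty placeholder file

demo_counts = count_images_per_class(demo_train_dir)
print(demo_counts)
```

Calling the helper on `train_dir` instead of the demo folder should give the per-class counts of the real dataset.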

We then set up the variables that we will use while pre-processing the dataset and training the network. Since we only have a small dataset and do not want to overfit our model right away, I recommend a maximum of 50 epochs with a batch size of 5. In addition, we will resize our pictures to 150x150 pixels, so each image becomes a tensor of shape [150, 150, 3] (height, width, and the three RGB channels).

batch_size = 5
epochs = 50
IMG_HEIGHT = 150
IMG_WIDTH = 150
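
To get a feeling for what a batch size of 5 means in practice, here is a small arithmetic sketch. It assumes the image counts that the data generators report further down (145 training and 14 validation images); the variable names `total_train` and `total_val` are introduced here for illustration.

```python
import math

batch_size = 5
total_train = 145  # training images reported by the generator further down
total_val = 14     # validation images reported by the generator further down

# Number of batches needed to see every image once per epoch.
steps_per_epoch = math.ceil(total_train / batch_size)
validation_steps = math.ceil(total_val / batch_size)

# Shape of one full batch of 150x150 RGB images.
batch_shape = (batch_size, 150, 150, 3)

print(steps_per_epoch, validation_steps, batch_shape)
```

With these numbers, one epoch corresponds to 29 training batches, and the last validation batch is only partially filled.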

The next step is data preparation. We will format the images into appropriately pre-processed floating point tensors before feeding them to the network. To do this, we will decode the contents of these images and convert them into a grid of pixel values according to their RGB content. After that, we will convert them into floating point tensors. Finally, we will rescale the tensors from values between 0 and 255 to values between 0 and 1, as neural networks prefer to deal with small input values. For all of this, we will use the `ImageDataGenerator` class provided by `tf.keras`. It can read images from disk and preprocess them into proper tensors. It will also set up generators that convert these images into batches of tensors, which is helpful when training the network.

train_image_generator = ImageDataGenerator(rescale=1./255) # Generator for our training data
validation_image_generator = ImageDataGenerator(rescale=1./255) # Generator for our validation data
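
To illustrate what `rescale=1./255` does to the pixel values, here is a minimal NumPy sketch on a made-up 2x2 RGB image. The array contents are invented for the example; only the scaling factor is taken from the generators above.

```python
import numpy as np

# A fake 2x2 RGB image with 8-bit pixel values (0..255).
pixels_uint8 = np.array([[[0, 128, 255], [64, 64, 64]],
                         [[255, 255, 255], [10, 20, 30]]], dtype=np.uint8)

# Equivalent of rescale=1./255: convert to float and scale into the range [0, 1].
pixels_scaled = pixels_uint8.astype("float32") * (1.0 / 255)

print(pixels_scaled.shape, pixels_scaled.dtype)
print(pixels_scaled.min(), pixels_scaled.max())
```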

After defining the generators for training and validation images, the flow_from_directory method loads the images from disk, applies rescaling, and resizes the images to the required dimensions.

train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical')
#Found 145 images belonging to 14 classes.


val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical')
#Found 14 images belonging to 14 classes.
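
With `class_mode='categorical'`, the generators return the labels as one-hot vectors, one row per image and one column per class. The following NumPy sketch shows what such a label batch looks like for 14 classes; the class ids in `class_ids` are made up for illustration.

```python
import numpy as np

num_classes = 14
# Integer class ids for a hypothetical batch of 5 images.
class_ids = np.array([0, 6, 6, 13, 3])

# One-hot encode: row i has a single 1.0 in column class_ids[i].
one_hot = np.eye(num_classes, dtype="float32")[class_ids]

print(one_hot.shape)
print(one_hot[1])
```

Taking the `argmax` of each row recovers the original integer class id.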

Let's have a look at our classes.

labels = (train_data_gen.class_indices)
labels = dict((v,k) for k,v in labels.items())
print(labels)

This worked perfectly. We should now see a dictionary of the 14 classes in our dataset: the Disney princesses, plus one class named 'Cats'.

{0: 'Anna', 1: 'Ariel', 2: 'Aurora', 3: 'Belle', 4: 'Cats', 5: 'Cinderella', 6: 'Elsa', 7: 'Jasmine', 8: 'Merida', 9: 'Moana', 10: 'Mulan', 11: 'Rapunzel', 12: 'Snow', 13: 'Tiana'}
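
This index-to-name dictionary is what lets us turn a model output back into a readable class name. Here is a minimal sketch of that lookup. The probability vector is invented for the example and is not a real prediction, and the name `index_to_name` is chosen here to avoid overwriting `labels`.

```python
import numpy as np

# Index-to-name mapping copied from the dictionary printed above.
index_to_name = {0: 'Anna', 1: 'Ariel', 2: 'Aurora', 3: 'Belle', 4: 'Cats',
                 5: 'Cinderella', 6: 'Elsa', 7: 'Jasmine', 8: 'Merida',
                 9: 'Moana', 10: 'Mulan', 11: 'Rapunzel', 12: 'Snow', 13: 'Tiana'}

# Hypothetical softmax output for a single image (made-up numbers).
probabilities = np.full(14, 0.02)
probabilities[6] = 0.74  # pretend the model favours class 6

# The predicted class is the index with the highest probability.
predicted_index = int(np.argmax(probabilities))
predicted_name = index_to_name[predicted_index]
print(predicted_index, predicted_name)
```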

Now let's visualize the training images by extracting a batch of images from the training generator and then plot five of them with matplotlib. I hope you can still recognize the images after rescaling and normalizing the pixel values.

sample_training_images, _ = next(train_data_gen)

# This function will plot images in the form of a grid with 1 row and 5 columns where images are placed in each column.
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

plotImages(sample_training_images[:5])

Below are five sample images from our training dataset. Remember, we have 145 training images and 14 validation images.


Now it is time to create our model. As mentioned above, we will use the pretrained MobileNetV2 model by Google as our base model, because we do not have a vast amount of training data for our Disney Princesses. This approach is called transfer learning and is especially valuable in cases where little data is available for training.

IMG_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

Downloading the model can take some time, as the weights file is a bit over 9 MB. The weights come in an HDF5 file. HDF5 is a file format and technology suite designed for managing extremely large and complex data collections. It includes a versatile data model that can represent very complex data objects and a wide variety of metadata.

==========================
OUTPUT
==========================

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

As we do not want to retrain the base model (MobileNetV2), we will exclude it from the training process by setting its trainable attribute to False.

base_model.trainable = False

Let's take a look at the base model architecture. Although it looks fairly complex, most of the model is built from standard components such as convolutional layers, batch normalization, and ReLU activation functions that provide non-linearity.

base_model.summary()
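
The output shapes and parameter counts in this summary can be checked by hand. The sketch below reproduces a few of them with plain arithmetic. It assumes 3x3 kernels with stride 2 in the downsampling layers, convolutions without bias terms, and four parameters per channel in batch normalization (scale, offset, moving mean, moving variance); these assumptions are consistent with the numbers printed in the summary, and the helper names are hypothetical.

```python
def conv_output_size(input_size, kernel_size, stride, padding_total=0):
    """Output length of an unpadded ('valid') convolution after adding padding_total zeros."""
    return (input_size + padding_total - kernel_size) // stride + 1


def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights of a standard convolution without bias."""
    return kernel_size * kernel_size * in_channels * out_channels


def depthwise_params(kernel_size, channels):
    """Weights of a depthwise convolution without bias (one filter per channel)."""
    return kernel_size * kernel_size * channels


def batchnorm_params(channels):
    """Scale, offset, moving mean, and moving variance for each channel."""
    return 4 * channels


# Spatial size after each stride-2 stage. The padding totals are the ones added by
# Conv1_pad, block_1_pad, block_3_pad, and block_6_pad in the summary.
size = 150
sizes = [size]
for padding_total in (1, 2, 1, 2):
    size = conv_output_size(size, kernel_size=3, stride=2, padding_total=padding_total)
    sizes.append(size)
print(sizes)

# Parameter counts of the first few layers.
print(conv2d_params(3, 3, 32))    # Conv1
print(batchnorm_params(32))       # bn_Conv1
print(depthwise_params(3, 32))    # expanded_conv_depthwise
print(conv2d_params(1, 32, 16))   # expanded_conv_project
```

The printed sizes match the 150, 75, 38, 19, 10 progression of the feature maps in the summary.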

==========================
OUTPUT
==========================

Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 150, 150, 3) 0
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 151, 151, 3)  0           input_1[0][0]
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 75, 75, 32)   864         Conv1_pad[0][0]
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 75, 75, 32)   128         Conv1[0][0]
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 75, 75, 32)   0           bn_Conv1[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 75, 75, 32)   288         Conv1_relu[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 75, 75, 32)   128         expanded_conv_depthwise[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 75, 75, 32)   0           expanded_conv_depthwise_BN[0][0]
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 75, 75, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 75, 75, 16)   64          expanded_conv_project[0][0]
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 75, 75, 96)   1536        expanded_conv_project_BN[0][0]
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 75, 75, 96)   384         block_1_expand[0][0]
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 75, 75, 96)   0           block_1_expand_BN[0][0]
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 77, 77, 96)   0           block_1_expand_relu[0][0]
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 38, 38, 96)   864         block_1_pad[0][0]
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 38, 38, 96)   384         block_1_depthwise[0][0]
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 38, 38, 96)   0           block_1_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 38, 38, 24)   2304        block_1_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 38, 38, 24)   96          block_1_project[0][0]
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 38, 38, 144)  3456        block_1_project_BN[0][0]
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 38, 38, 144)  576         block_2_expand[0][0]
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 38, 38, 144)  0           block_2_expand_BN[0][0]
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 38, 38, 144)  1296        block_2_expand_relu[0][0]
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 38, 38, 144)  576         block_2_depthwise[0][0]
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 38, 38, 144)  0           block_2_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 38, 38, 24)   3456        block_2_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 38, 38, 24)   96          block_2_project[0][0]
__________________________________________________________________________________________________
block_2_add (Add)               (None, 38, 38, 24)   0           block_1_project_BN[0][0]
                                                                 block_2_project_BN[0][0]
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 38, 38, 144)  3456        block_2_add[0][0]
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 38, 38, 144)  576         block_3_expand[0][0]
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 38, 38, 144)  0           block_3_expand_BN[0][0]
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 39, 39, 144)  0           block_3_expand_relu[0][0]
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 19, 19, 144)  1296        block_3_pad[0][0]
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 19, 19, 144)  576         block_3_depthwise[0][0]
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 19, 19, 144)  0           block_3_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 19, 19, 32)   4608        block_3_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 19, 19, 32)   128         block_3_project[0][0]
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 19, 19, 192)  6144        block_3_project_BN[0][0]
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 19, 19, 192)  768         block_4_expand[0][0]
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 19, 19, 192)  0           block_4_expand_BN[0][0]
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 19, 19, 192)  1728        block_4_expand_relu[0][0]
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 19, 19, 192)  768         block_4_depthwise[0][0]
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 19, 19, 192)  0           block_4_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 19, 19, 32)   6144        block_4_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 19, 19, 32)   128         block_4_project[0][0]
__________________________________________________________________________________________________
block_4_add (Add)               (None, 19, 19, 32)   0           block_3_project_BN[0][0]
                                                                 block_4_project_BN[0][0]
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 19, 19, 192)  6144        block_4_add[0][0]
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 19, 19, 192)  768         block_5_expand[0][0]
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 19, 19, 192)  0           block_5_expand_BN[0][0]
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 19, 19, 192)  1728        block_5_expand_relu[0][0]
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 19, 19, 192)  768         block_5_depthwise[0][0]
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 19, 19, 192)  0           block_5_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 19, 19, 32)   6144        block_5_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 19, 19, 32)   128         block_5_project[0][0]
__________________________________________________________________________________________________
block_5_add (Add)               (None, 19, 19, 32)   0           block_4_add[0][0]
                                                                 block_5_project_BN[0][0]
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 19, 19, 192)  6144        block_5_add[0][0]
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 19, 19, 192)  768         block_6_expand[0][0]
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 19, 19, 192)  0           block_6_expand_BN[0][0]
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]
                                                                 block_7_project_BN[0][0]
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]
                                                                 block_8_project_BN[0][0]
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]
                                                                 block_9_project_BN[0][0]
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]
                                                                 block_11_project_BN[0][0]
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]
                                                                 block_12_project_BN[0][0]
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]
                                                                 block_14_project_BN[0][0]
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]
                                                                 block_15_project_BN[0][0]
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984

Let's check the shape of the feature batch that the base model produces, since our new layers will sit on top of it. This is simply the base model's output for a single image.

final_img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])  # size is given as [height, width]

final_img_tfl = np.expand_dims(final_img, axis=0)
print(final_img_tfl.shape)

feature_batch = base_model(final_img_tfl)
print(feature_batch.shape)

==========================
OUTPUT
==========================

(1, 150, 150, 3)
(1, 5, 5, 1280)
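
The 5 x 5 spatial size can be reproduced by hand. The following is a minimal sketch, assuming that MobileNetV2 halves the resolution five times (overall stride 32) and that each stride-2 stage rounds up, which matches the 10 -> 5 step visible in the summary above. The helper name `downsample` is ours and not part of Keras.

```python
import math

def downsample(size, stages=5):
    """Apply `stages` stride-2 reductions, rounding up each time."""
    sizes = [size]
    for _ in range(stages):
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes

print(downsample(150))  # [150, 75, 38, 19, 10, 5]
```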

We will put a GlobalAveragePooling2D layer on top of the base model. Remember, the feature output currently has a shape of (1, 5, 5, 1280). For a classification into 14 classes, however, we want a (1, 14) tensor, so we first pool over the spatial dimensions and then add a dense layer with 14 units.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

# The input size (1280) is inferred from the pooled features, so no input_shape is needed
prediction_layer = tf.keras.layers.Dense(units=14, activation='softmax')
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

==========================
OUTPUT
==========================

(1, 1280)
(1, 14)
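
To see what these two layers do numerically, here is a minimal NumPy sketch. It assumes random stand-in values instead of real MobileNetV2 features and randomly initialised weights, so only the shapes are meaningful.

```python
import numpy as np

# Stand-in for the base model output: random values with the same shape.
rng = np.random.default_rng(0)
feature_batch = rng.random((1, 5, 5, 1280)).astype(np.float32)

# Global average pooling = mean over the two spatial axes (height, width).
pooled = feature_batch.mean(axis=(1, 2))
print(pooled.shape)  # (1, 1280)

# A dense softmax layer is a matrix product plus bias, followed by softmax.
weights = rng.normal(size=(1280, 14)).astype(np.float32) * 0.01
bias = np.zeros(14, dtype=np.float32)
logits = pooled @ weights + bias
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
print(probs.shape)  # (1, 14)
```

Each row of `probs` sums to 1, which is what lets us read the model output as one percentage per princess later on.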

The next step is compiling our model. We choose the Adam optimizer and the categorical cross-entropy loss function. To view training and validation accuracy for each epoch, we pass the metrics argument.

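# The prediction layer already applies softmax, so the loss receives
# probabilities and from_logits must stay at its default (False).
# Note: the training log below was recorded with from_logits=True.
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])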

Let's have a look at our complete model now.

model.summary()

Let's hope that these 17,934 trainable parameters are enough to reach a decent accuracy for our Disney Princess predictions.

==========================
OUTPUT
==========================

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
mobilenetv2_1.00_224 (Model) (None, 5, 5, 1280)        2257984
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1280)              0
_________________________________________________________________
dense (Dense)                (None, 14)                17934
=================================================================
Total params: 2,275,918
Trainable params: 17,934
Non-trainable params: 2,257,984
_________________________________________________________________
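
The parameter counts in this summary can be checked with a few lines of arithmetic. This is a small sketch using only the numbers printed above; the variable names are ours.

```python
features = 1280   # channels after global average pooling
classes = 14      # one output unit per princess

dense_params = features * classes + classes  # weights + biases
base_params = 2_257_984                      # frozen MobileNetV2 (from the summary)

print(dense_params)                # 17934
print(base_params + dense_params)  # 2275918
```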

Let the training begin:

# Model.fit accepts generators directly in TF 2.1+; fit_generator is deprecated.
history = model.fit(
    train_data_gen,
    steps_per_epoch=5,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=5
)

==========================
OUTPUT
==========================

Epoch 1/50
5/5 [==============================] - 2s 368ms/step - loss: 2.6474 - accuracy: 0.0800 - val_loss: 2.6210 - val_accuracy: 0.1250
Epoch 2/50
5/5 [==============================] - 1s 146ms/step - loss: 2.5233 - accuracy: 0.2800 - val_loss: 2.5921 - val_accuracy: 0.1667
Epoch 3/50
5/5 [==============================] - 1s 165ms/step - loss: 2.5650 - accuracy: 0.1600 - val_loss: 2.5943 - val_accuracy: 0.1250
Epoch 4/50
5/5 [==============================] - 1s 150ms/step - loss: 2.5943 - accuracy: 0.1200 - val_loss: 2.5894 - val_accuracy: 0.1250
Epoch 5/50
5/5 [==============================] - 1s 158ms/step - loss: 2.5595 - accuracy: 0.2000 - val_loss: 2.5373 - val_accuracy: 0.2917
.............
Epoch 45/50
5/5 [==============================] - 1s 137ms/step - loss: 1.8118 - accuracy: 0.9600 - val_loss: 2.1941 - val_accuracy: 0.5833
Epoch 46/50
5/5 [==============================] - 1s 159ms/step - loss: 1.7876 - accuracy: 1.0000 - val_loss: 2.1710 - val_accuracy: 0.6250
Epoch 47/50
5/5 [==============================] - 1s 146ms/step - loss: 1.8405 - accuracy: 0.9200 - val_loss: 2.2322 - val_accuracy: 0.5417
Epoch 48/50
5/5 [==============================] - 1s 155ms/step - loss: 1.7789 - accuracy: 1.0000 - val_loss: 2.1696 - val_accuracy: 0.5833
Epoch 49/50
5/5 [==============================] - 1s 133ms/step - loss: 1.7679 - accuracy: 1.0000 - val_loss: 2.2126 - val_accuracy: 0.5417
Epoch 50/50
5/5 [==============================] - 1s 158ms/step - loss: 1.8199 - accuracy: 0.9600 - val_loss: 2.1558 - val_accuracy: 0.6250
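
A side note on the loss values in this log: they start near 2.64 and level off around 1.77, even in epochs where training accuracy is 1.0000. One plausible explanation, offered here as an assumption rather than something the log proves, is that a softmax output is being passed to a loss configured with from_logits=True, so softmax is effectively applied twice. The NumPy sketch below (helper names are ours) computes the loss floor this would impose for 14 classes.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def cross_entropy(probs, true_index):
    return -np.log(probs[true_index])

n_classes = 14
one_hot = np.zeros(n_classes)
one_hot[5] = 1.0  # a perfectly confident, correct prediction

# Softmax applied a second time to probabilities flattens them again.
floor = cross_entropy(softmax(one_hot), 5)
chance = cross_entropy(np.full(n_classes, 1 / n_classes), 5)

print(f"{floor:.2f}")   # about 1.75
print(f"{chance:.2f}")  # about 2.64
```

Both numbers sit close to the range seen in the log, which is why the loss can look stuck while accuracy keeps improving.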

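OK, after 50 epochs of training we reach 96% training accuracy and 63% validation accuracy. That is quite impressive with only 10 training images for each Disney Princess, although the gap between the two numbers shows that the model overfits the small training set. Let's visualize our training and validation progress.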

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss=history.history['loss']
val_loss=history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


Cool, we have built a decent model by leveraging transfer learning. Let's give it a shot and throw some Disney images at it to see how it is doing...

image_path = "/content/gdrive/My Drive/Datasets"

def loadImages(path):
    '''Put files into lists and return them as one list with all images
     in the folder'''
    image_file = sorted([os.path.join(path, file)
                          for file in os.listdir(path )
                          if file.endswith('.JPG')])
    return image_file

image_list = loadImages(image_path)
print(image_list)

path_string = image_list[0]  # use the first image in the folder

print(path_string)

img = tf.io.read_file(path_string)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
final_img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])  # size is given as [height, width]

# Show the resized input image
plt.subplot(121)
plt.imshow(final_img)



#Expand Tensor for Model (Input shape)
y = np.expand_dims(final_img, axis=0)

#Predict Image Tensor with model
prediction = model.predict(y)
prediction_squeeze = np.squeeze(prediction, axis=0)

# labels maps class index -> princess name; print one probability per class
for key, value in labels.items():
    print("{0:.0%}".format(prediction_squeeze[key]), value)

This is what we hoped for: we got an Elsa :).

==========================
OUTPUT
==========================

0% Anna
0% Ariel
0% Aurora
0% Belle
2% Cinderella
91% Elsa
0% Jasmine
0% Merida
2% Moana
0% Mulan
0% Pocahontas
3% Rapunzel
0% Snow
1% Tiana
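
With 14 classes, it is often handier to print only the most likely candidates. The following is a minimal sketch, assuming `labels` is a dict that maps the class index to the princess name as in the loop above. The probabilities are stand-ins taken from the rounded percentages in this output, not the raw model values.

```python
import numpy as np

# Hypothetical stand-ins: the rounded percentages from the output above,
# and a labels dict mapping class index -> name as used by the article.
names = ["Anna", "Ariel", "Aurora", "Belle", "Cinderella", "Elsa", "Jasmine",
         "Merida", "Moana", "Mulan", "Pocahontas", "Rapunzel", "Snow", "Tiana"]
labels = dict(enumerate(names))
probs = np.array([0, 0, 0, 0, 2, 91, 0, 0, 2, 0, 0, 3, 0, 1]) / 100

top2 = np.argsort(probs)[::-1][:2]  # indices of the two highest scores
for idx in top2:
    print("{0:.0%}".format(probs[idx]), labels[int(idx)])
# 91% Elsa
# 3% Rapunzel
```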

Yes, we did it! A cool application of Transfer Learning (using MobileNetV2) for image classification, and especially easy to handle with the built-in Keras functions.


